{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[Learn the Basics](intro.html) ||\n",
        "**Quickstart** ||\n",
        "[Tensors](tensorqs_tutorial.html) ||\n",
        "[Datasets & DataLoaders](data_tutorial.html) ||\n",
        "[Transforms](transforms_tutorial.html) ||\n",
        "[Build Model](buildmodel_tutorial.html) ||\n",
        "[Autograd](autogradqs_tutorial.html) ||\n",
        "[Optimization](optimization_tutorial.html) ||\n",
        "[Save & Load Model](saveloadrun_tutorial.html)\n",
        "\n",
        "# Quickstart\n",
        "\n",
        "This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.\n",
        "\n",
        "## Working with data\n",
        "\n",
        "PyTorch has two [primitives to work with data](https://pytorch.org/docs/stable/data.html):\n",
        "`torch.utils.data.DataLoader` and `torch.utils.data.Dataset`.\n",
        "`Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around\n",
        "the `Dataset`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn, Tensor\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets\n",
        "from torchvision.transforms import ToTensor\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "PyTorch offers domain-specific libraries such as [TorchText](https://pytorch.org/text/stable/index.html),\n",
        "[TorchVision](https://pytorch.org/vision/stable/index.html), and [TorchAudio](https://pytorch.org/audio/stable/index.html),\n",
        "all of which include datasets. For this tutorial, we will be using a TorchVision dataset.\n",
        "\n",
        "The `torchvision.datasets` module contains `Dataset` objects for many real-world vision datasets, such as\n",
        "CIFAR and COCO ([full list here](https://pytorch.org/vision/stable/datasets.html)). In this tutorial, we\n",
        "use the FashionMNIST dataset. Every TorchVision `Dataset` accepts two arguments, `transform` and\n",
        "`target_transform`, to modify the samples and labels respectively.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "# Root directory where the dataset is stored or downloaded; adjust this path for your machine.\n",
        "ROOT = \"E:/sfxData/DeepLearning\"\n",
        "\n",
        "# Download training data from open datasets.\n",
        "training_data = datasets.FashionMNIST(\n",
        "    root=ROOT,\n",
        "    train=True,\n",
        "    download=True,\n",
        "    transform=ToTensor(),\n",
        ")\n",
        "\n",
        "# Download test data from open datasets.\n",
        "test_data = datasets.FashionMNIST(\n",
        "    root=ROOT,\n",
        "    train=False,\n",
        "    download=True,\n",
        "    transform=ToTensor(),\n",
        ")\n",
        "\n",
        "# print(training_data.data.shape)\n",
        "# print(training_data.targets.shape)\n",
        "\n",
        "print(test_data)\n",
        "print(test_data.data[0])\n",
        "print(test_data.targets)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We pass the `Dataset` as an argument to `DataLoader`. This wraps an iterable over our dataset, and supports\n",
        "automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, i.e. each element\n",
        "in the dataloader iterable will return a batch of 64 features and labels.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "batch_size = 64\n",
        "\n",
        "# Create data loaders.\n",
        "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
        "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n",
        "\n",
        "for X, y in test_dataloader:\n",
        "    print(f\"Shape of X [N, C, H, W]: {X.shape}\")  # [64,1,28,28]\n",
        "    print(f\"Shape of y: {y.shape} {y.dtype}\")     # [64]\n",
        "    break\n"
      ]
    },
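    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of how the two primitives fit together, the sketch below defines a small map-style `Dataset`\n",
        "over random tensors and wraps it in a `DataLoader`. The class name `ToyDataset`, its size, and the random data are\n",
        "hypothetical stand-ins, not part of FashionMNIST; the point is only that `DataLoader` batches whatever\n",
        "`__len__` and `__getitem__` expose.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "\n",
        "class ToyDataset(Dataset):\n",
        "    # Hypothetical map-style dataset: random 1x28x28 'images' with integer labels.\n",
        "    def __init__(self, n_samples=10):\n",
        "        generator = torch.Generator().manual_seed(0)\n",
        "        self.samples = torch.rand(n_samples, 1, 28, 28, generator=generator)\n",
        "        self.labels = torch.randint(0, 10, (n_samples,), generator=generator)\n",
        "\n",
        "    def __len__(self):\n",
        "        # DataLoader uses this to know how many samples exist.\n",
        "        return len(self.samples)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        # DataLoader calls this per index and stacks the results into a batch.\n",
        "        return self.samples[idx], self.labels[idx]\n",
        "\n",
        "\n",
        "toy_loader = DataLoader(ToyDataset(), batch_size=4)\n",
        "toy_batch_sizes = [len(labels) for _, labels in toy_loader]\n",
        "\n",
        "# 10 samples with batch_size=4 give batches of 4, 4 and 2.\n",
        "print(toy_batch_sizes)\n"
      ]
    },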
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read more about [loading data in PyTorch](data_tutorial.html).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Creating Models\n",
        "\n",
        "To define a neural network in PyTorch, we create a class that inherits\n",
        "from [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). We define the layers of the network\n",
        "in the `__init__` function and specify how data will pass through the network in the `forward` function. To accelerate\n",
        "operations in the neural network, we move it to the GPU if available.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "# Get cpu or gpu device for training.\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "print(f\"Using {device} device\")\n",
        "\n",
        "# Define model\n",
        "\n",
        "\n",
        "class NeuralNetwork(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.flatten = nn.Flatten()\n",
        "        self.linear_relu_stack = nn.Sequential(\n",
        "            nn.Linear(28*28, 512),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(512, 512),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(512, 10)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.flatten(x)\n",
        "        logits = self.linear_relu_stack(x)\n",
        "        return logits\n",
        "\n",
        "\n",
        "# model = NeuralNetwork().to(device)\n",
        "model = NeuralNetwork()\n",
        "model.to(device)\n",
        "\n",
        "print(model)\n"
      ]
    },
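    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the data flow concrete, here is a minimal sketch that pushes a fake batch through a stand-in network with the\n",
        "same layer sizes as above and counts its trainable parameters. The names `sketch_model`, `sketch_batch`, and\n",
        "`n_params` are assumptions made for this illustration, and the input is random noise rather than FashionMNIST images.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "# Stand-in with the same layer sizes as the network above (hypothetical name).\n",
        "sketch_model = nn.Sequential(\n",
        "    nn.Flatten(),\n",
        "    nn.Linear(28 * 28, 512),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(512, 512),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(512, 10),\n",
        ")\n",
        "\n",
        "sketch_batch = torch.rand(3, 1, 28, 28)          # 3 fake grayscale images\n",
        "sketch_logits = sketch_model(sketch_batch)       # raw scores, one row per image\n",
        "sketch_probs = nn.Softmax(dim=1)(sketch_logits)  # each row now sums to 1\n",
        "\n",
        "# Each nn.Linear(i, o) holds i * o weights plus o biases.\n",
        "n_params = sum(p.numel() for p in sketch_model.parameters())\n",
        "print(sketch_logits.shape, n_params)\n"
      ]
    },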
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read more about [building neural networks in PyTorch](buildmodel_tutorial.html).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Optimizing the Model Parameters\n",
        "\n",
        "To train a model, we need a [loss function](https://pytorch.org/docs/stable/nn.html#loss-functions)\n",
        "and an [optimizer](https://pytorch.org/docs/stable/optim.html).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "loss_fn = nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n"
      ]
    },
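    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below is one way to see what these two objects do on hand-made numbers. It checks that\n",
        "`nn.CrossEntropyLoss` applied to raw logits matches the mean negative log-softmax of the target classes, and that a\n",
        "single `SGD` step moves a parameter against its gradient. All tensors and names here (`demo_logits`,\n",
        "`demo_weights`, and so on) are made up for illustration.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from torch import nn\n",
        "\n",
        "# Two samples, three classes: raw (unnormalized) scores and integer class labels.\n",
        "demo_logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])\n",
        "demo_targets = torch.tensor([0, 2])\n",
        "\n",
        "demo_loss = nn.CrossEntropyLoss()(demo_logits, demo_targets)\n",
        "\n",
        "# Same quantity by hand: mean negative log-probability of the correct class.\n",
        "log_probs = F.log_softmax(demo_logits, dim=1)\n",
        "manual_loss = -log_probs[torch.arange(2), demo_targets].mean()\n",
        "\n",
        "# One SGD step on f(w) = sum(w^2): the gradient is 2w, so w becomes w - lr * 2w.\n",
        "demo_weights = torch.tensor([1.0, -2.0], requires_grad=True)\n",
        "demo_optimizer = torch.optim.SGD([demo_weights], lr=0.1)\n",
        "(demo_weights ** 2).sum().backward()\n",
        "demo_optimizer.step()\n",
        "\n",
        "print(demo_loss.item(), manual_loss.item(), demo_weights.detach())\n"
      ]
    },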
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and\n",
        "backpropagates the prediction error to adjust the model's parameters.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "def train(dataloader, model, loss_fn, optimizer):\n",
        "    size = len(dataloader.dataset)\n",
        "    model.train()\n",
        "    for batch, (X, y) in enumerate(dataloader):\n",
        "        X, y = X.to(device), y.to(device)\n",
        "\n",
        "        # Compute prediction error\n",
        "        pred = model(X)\n",
        "        loss = loss_fn(pred, y)\n",
        "\n",
        "        # Backpropagation\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        if batch % 100 == 0:\n",
        "            loss, current = loss.item(), batch * len(X)\n",
        "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We also check the model's performance against the test dataset to ensure it is learning.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "def test(dataloader, model, loss_fn):\n",
        "    size = len(dataloader.dataset)\n",
        "    num_batches = len(dataloader)\n",
        "    model.eval()\n",
        "    test_loss, correct = 0, 0\n",
        "    with torch.no_grad():\n",
        "        for X, y in dataloader:\n",
        "            X, y = X.to(device), y.to(device)\n",
        "            pred = model(X)\n",
        "            test_loss += loss_fn(pred, y).item()\n",
        "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
        "    test_loss /= num_batches\n",
        "    correct /= size\n",
        "    print(\n",
        "        f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training process is conducted over several iterations (_epochs_). During each epoch, the model learns\n",
        "parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the\n",
        "accuracy increase and the loss decrease with every epoch.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "epochs = 5\n",
        "for t in range(epochs):\n",
        "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
        "    train(train_dataloader, model, loss_fn, optimizer)\n",
        "    test(test_dataloader, model, loss_fn)\n",
        "print(\"Done!\")\n"
      ]
    },
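    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same epoch pattern can be watched on a much smaller problem. The following sketch, which assumes a synthetic\n",
        "line-fitting task (`toy_x`, `toy_y`) and a one-layer model (`toy_net`) that are not part of this tutorial, records the\n",
        "loss at every epoch so the downward trend is easy to inspect.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Synthetic data for the line y = 2x + 1 (hypothetical toy task).\n",
        "toy_x = torch.linspace(-1, 1, 20).unsqueeze(1)  # shape [20, 1]\n",
        "toy_y = 2 * toy_x + 1\n",
        "\n",
        "toy_net = nn.Linear(1, 1)\n",
        "toy_loss_fn = nn.MSELoss()\n",
        "toy_optimizer = torch.optim.SGD(toy_net.parameters(), lr=0.1)\n",
        "\n",
        "toy_losses = []\n",
        "for epoch in range(50):\n",
        "    # Here one epoch is a single full-batch step: predict, measure error, update.\n",
        "    toy_loss = toy_loss_fn(toy_net(toy_x), toy_y)\n",
        "    toy_optimizer.zero_grad()\n",
        "    toy_loss.backward()\n",
        "    toy_optimizer.step()\n",
        "    toy_losses.append(toy_loss.item())\n",
        "\n",
        "print(f'first loss: {toy_losses[0]:.4f}, last loss: {toy_losses[-1]:.4f}')\n"
      ]
    },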
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read more about [Training your model](optimization_tutorial.html).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Saving Models\n",
        "\n",
        "A common way to save a model is to serialize the internal state dictionary (containing the model parameters).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "torch.save(model.state_dict(), \"model.pth\")\n",
        "print(\"Saved PyTorch Model State to model.pth\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Loading Models\n",
        "\n",
        "The process for loading a model includes re-creating the model structure and loading\n",
        "the state dictionary into it.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "model = NeuralNetwork()\n",
        "model.load_state_dict(torch.load(\"model.pth\"))\n"
      ]
    },
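    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the save/load round trip, assuming a tiny `nn.Linear` layer and an in-memory buffer in place of\n",
        "`model.pth`, shows what the state dictionary contains and that the restored copy produces the same output. The names\n",
        "`source_net`, `restored_net`, and `probe` are hypothetical; `map_location` is used here so that tensors are loaded\n",
        "onto the CPU regardless of where they were saved.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import io\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "source_net = nn.Linear(4, 2)\n",
        "\n",
        "# Serialize only the parameters (the state dictionary), not the class itself.\n",
        "buffer = io.BytesIO()\n",
        "torch.save(source_net.state_dict(), buffer)\n",
        "buffer.seek(0)\n",
        "\n",
        "# Re-create the same structure, then load the saved parameters into it.\n",
        "restored_net = nn.Linear(4, 2)\n",
        "restored_state = torch.load(buffer, map_location='cpu')\n",
        "restored_net.load_state_dict(restored_state)\n",
        "restored_net.eval()\n",
        "\n",
        "probe = torch.ones(1, 4)\n",
        "with torch.no_grad():\n",
        "    same_output = torch.equal(source_net(probe), restored_net(probe))\n",
        "\n",
        "print(list(restored_state.keys()), same_output)\n"
      ]
    },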
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This model can now be used to make predictions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "dotnet_interactive": {
          "language": "csharp"
        }
      },
      "outputs": [],
      "source": [
        "classes = [\n",
        "    \"T-shirt/top\",\n",
        "    \"Trouser\",\n",
        "    \"Pullover\",\n",
        "    \"Dress\",\n",
        "    \"Coat\",\n",
        "    \"Sandal\",\n",
        "    \"Shirt\",\n",
        "    \"Sneaker\",\n",
        "    \"Bag\",\n",
        "    \"Ankle boot\",\n",
        "]\n",
        "\n",
        "model.eval()\n",
        "x, y = test_data[1][0], test_data[1][1]\n",
        "\n",
        "with torch.no_grad():\n",
        "    pred = model(x)\n",
        "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
        "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read more about [Saving & Loading your model](saveloadrun_tutorial.html).\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.9.9 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.9 (tags/v3.9.9:ccb0e6a, Nov 15 2021, 18:08:50) [MSC v.1929 64 bit (AMD64)]"
    },
    "vscode": {
      "interpreter": {
        "hash": "cf18841ace8313d0bc088ca146c17a6c0040e82121d5cb75c0ea07172309253d"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
